{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example demonstrates one possible way to cluster data sets that are too large to fit into memory using MDTraj and `scipy.cluster`. The idea behind the algorithm is to cluster every N-th frame directly and then, treating those clusters as fixed, \"assign\" the remaining frames to the nearest cluster. It's not the most sophisticated algorithm, but it's a good demonstration of how MDTraj can be integrated with other data science tools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "from collections import defaultdict\n",
    "import mdtraj as md\n",
    "import numpy as np\n",
    "import scipy.cluster.hierarchy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "stride = 5\n",
    "subsampled = md.load('ala2.h5', stride=stride)\n",
    "print(subsampled)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the pairwise RMSD between all of the subsampled frames. For N frames,\n",
    "the distance matrix requires memory proportional to N^2, which is why we might need to stride."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "distances = np.empty((subsampled.n_frames, subsampled.n_frames))\n",
    "for i in range(subsampled.n_frames):\n",
    "    distances[i] = md.rmsd(subsampled, subsampled, i)"
   ]
  },
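  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the N^2 memory cost concrete, here is a rough back-of-the-envelope sketch.\n",
    "The helper `square_matrix_bytes` and the frame counts are hypothetical and chosen only\n",
    "for illustration; the estimate assumes a dense `float64` matrix like the one allocated\n",
    "with `np.empty` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def square_matrix_bytes(n_frames, dtype=np.float64):\n",
    "    # Bytes needed for a dense n_frames x n_frames distance matrix.\n",
    "    return n_frames * n_frames * np.dtype(dtype).itemsize\n",
    "\n",
    "# Hypothetical trajectory lengths, not taken from ala2.h5.\n",
    "for n_frames in (1000, 10000, 100000):\n",
    "    gib = square_matrix_bytes(n_frames) / 2.0**30\n",
    "    print('{:>7d} frames -> {:.3f} GiB'.format(n_frames, gib))"
   ]
  },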
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the distances, we can use our favorite clustering\n",
    "algorithm. I like Ward's method."
   ]
  },
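  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import squareform\n",
    "\n",
    "n_clusters = 3\n",
    "# ward() expects a condensed distance matrix; a square 2-D array would be\n",
    "# interpreted as raw observation vectors. RMSD is symmetric only up to\n",
    "# floating-point error, so squareform's strict checks are skipped.\n",
    "condensed = squareform(distances, checks=False)\n",
    "linkage = scipy.cluster.hierarchy.ward(condensed)\n",
    "labels = scipy.cluster.hierarchy.fcluster(linkage, t=n_clusters, criterion='maxclust')\n",
    "\n",
    "labels"
   ]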
  },
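  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a self-contained check of the clustering calls, the sketch below runs the same\n",
    "`ward` and `fcluster` steps on synthetic 2-D points instead of RMSD values.\n",
    "The point cloud, the seed, and the `demo_*` names are assumptions made for illustration.\n",
    "`pdist` returns distances in condensed form, which is the form `ward` expects when it is given distances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.cluster.hierarchy\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "demo_rng = np.random.RandomState(0)\n",
    "# Three well-separated synthetic blobs of 20 points each (illustrative data only).\n",
    "demo_centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])\n",
    "demo_points = np.vstack([c + 0.1 * demo_rng.randn(20, 2) for c in demo_centers])\n",
    "\n",
    "# pdist returns a condensed 1-D array holding n * (n - 1) / 2 distances.\n",
    "demo_condensed = pdist(demo_points)\n",
    "demo_linkage = scipy.cluster.hierarchy.ward(demo_condensed)\n",
    "demo_labels = scipy.cluster.hierarchy.fcluster(demo_linkage, t=3, criterion='maxclust')\n",
    "\n",
    "# fcluster numbers its flat clusters starting from 1, not 0.\n",
    "print(np.unique(demo_labels))"
   ]
  },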
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we need to extract `n_leaders_per_cluster` random samples from each of the clusters.\n",
    "One way to do this is by building a map from each of the cluster labels\n",
    "to the list of the indices of the subsampled conformations which belong to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "mapping = defaultdict(list)\n",
    "for i, label in enumerate(labels):\n",
    "    mapping[label].append(i)\n",
    "\n",
    "mapping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now we can iterate through the mapping and select `n_leaders_per_cluster` random\n",
    "samples from each list. As we select them, we'll extract the\n",
    "conformation and append it to a new trajectory called `leaders`, which\n",
    "will start empty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n_leaders_per_cluster = 2\n",
    "leaders = md.Trajectory(xyz=np.empty((0, subsampled.n_atoms, 3)),\n",
    "                        topology=subsampled.topology)\n",
    "leader_labels = []\n",
    "for label, indices in mapping.items():\n",
    "    leaders = leaders.join(subsampled[np.random.choice(indices, n_leaders_per_cluster)])\n",
    "    leader_labels.extend([label] * n_leaders_per_cluster)\n",
    "\n",
    "print(leaders)\n",
    "print(leader_labels)"
   ]
  },
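  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.random.choice` samples with replacement by default, so the cell above can pick the\n",
    "same frame twice. If distinct leaders are preferred, one option is sketched below on a toy mapping.\n",
    "The helper `pick_leaders` and the toy data are hypothetical; the helper caps the sample size at the\n",
    "cluster size so that small clusters do not raise an error. When it returns fewer than\n",
    "`n_leaders_per_cluster` frames, `leader_labels` would need to be extended by the number actually picked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def pick_leaders(indices, n_wanted, rng):\n",
    "    # Sample distinct frame indices, never more than the cluster contains.\n",
    "    n_pick = min(n_wanted, len(indices))\n",
    "    return rng.choice(indices, size=n_pick, replace=False)\n",
    "\n",
    "toy_rng = np.random.RandomState(42)\n",
    "# Toy mapping from cluster label to frame indices (illustrative only).\n",
    "toy_mapping = {1: [0, 3, 4, 7], 2: [1, 2], 3: [5]}\n",
    "toy_leaders = {label: pick_leaders(idx, 2, toy_rng)\n",
    "               for label, idx in toy_mapping.items()}\n",
    "print(toy_leaders)"
   ]
  },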
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now our `leaders` trajectory contains a set of representative conformations\n",
    "for each cluster. Here comes the second pass of the two-pass clustering.\n",
    "Let's now consider these clusters as fixed objects and iterate through\n",
    "*every* frame in our data set, assigning each frame to the cluster\n",
    "it's closest to. We take the simple approach here of computing the distance\n",
    "from each frame to each leader and assigning the frame to the cluster whose\n",
    "leader is closest.\n",
    "\n",
    "Note that this whole algorithm never requires having the entire\n",
    "dataset in memory at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "labels = []\n",
    "for frame in md.iterload('ala2.h5', chunk=1):\n",
    "    labels.append(leader_labels[np.argmin(md.rmsd(leaders, frame, 0))])\n",
    "labels = np.array(labels)\n",
    "\n",
    "print(labels)\n",
    "print(labels.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
